package nq

import (
	"io/ioutil"
	"log"
	"sync"
	"time"

	"github.com/nsqio/go-nsq"
	"github.com/panjf2000/gnet/pool/bytebuffer"
	"github.com/tal-tech/go-zero/core/executors"
	"github.com/tal-tech/go-zero/core/logx"
	"github.com/tal-tech/go-zero/core/syncx"
	"github.com/tal-tech/go-zero/core/threading"
)

type (
	// PushOption customizes the batching behavior of a Pusher.
	PushOption func(options *chunkOptions)

	// Pusher publishes messages to a single NSQ topic, batching them
	// through a chunk executor and recycling message buffers once the
	// producer acknowledges each transaction.
	Pusher struct {
		produer      *nsq.Producer // NOTE(review): field name is a typo for "producer"
		topic        string
		responseChan chan *nsq.ProducerTransaction // producer acks; drained by response()
		executor     *executors.ChunkExecutor      // batches pending messages
		closed       *syncx.AtomicBool             // set once Close has begun
		once         sync.Once                     // ensures Close runs its shutdown logic only once
		routineGroup *threading.RoutineGroup       // tracks the background response goroutine
	}

	// chunkOptions holds the tunables applied via PushOption.
	chunkOptions struct {
		chunkSize     int           // max buffered bytes before a flush
		flushInterval time.Duration // max time between flushes
	}
)

// NewPusher creates a Pusher that batches messages and publishes them to
// the given NSQ topic at addrs (an nsqd "host:port"). opts tune the batch
// byte size and flush interval. It returns nil if the underlying producer
// cannot be created; the error is logged rather than returned.
func NewPusher(addrs string, topic string, opts ...PushOption) *Pusher {
	config := nsq.NewConfig()
	config.DialTimeout = 5 * time.Second
	producer, err := nsq.NewProducer(addrs, config)
	if err != nil {
		logx.Error(err)
		return nil
	}
	// Silence go-nsq's own logger; errors are reported through logx instead.
	producer.SetLogger(log.New(ioutil.Discard, "", log.LstdFlags), nsq.LogLevelError)

	pusher := &Pusher{
		produer:      producer,
		topic:        topic,
		closed:       syncx.NewAtomicBool(),
		responseChan: make(chan *nsq.ProducerTransaction),
		routineGroup: threading.NewRoutineGroup(),
	}
	// Drain producer acknowledgements (and recycle buffers) in the background.
	pusher.routineGroup.RunSafe(pusher.response)

	// closeOnce guards responseChan: the executor callback may run more than
	// once after closed is set (e.g. an interval flush racing Close's flush),
	// and closing an already-closed channel would panic.
	var closeOnce sync.Once
	pusher.executor = executors.NewChunkExecutor(func(tasks []interface{}) {
		chunk := make([][]byte, len(tasks))
		for i := range tasks {
			chunk[i] = tasks[i].(*bytebuffer.ByteBuffer).Bytes()
		}
		var err error
		if pusher.closed.True() {
			// Shutting down: publish the final batch synchronously so it is
			// flushed before the producer stops.
			err = pusher.produer.MultiPublish(topic, chunk)
			// No async ack will arrive for these, so recycle the pooled
			// buffers here (safe: MultiPublish has already returned).
			for i := range tasks {
				bytebuffer.Put(tasks[i].(*bytebuffer.ByteBuffer))
			}
			// Release the channel so the response goroutine can exit.
			closeOnce.Do(func() {
				close(pusher.responseChan)
			})
		} else {
			err = pusher.produer.MultiPublishAsync(topic, chunk, pusher.responseChan, tasks)
		}
		if err != nil {
			logx.Error(err)
		}
	}, newOptions(opts)...)

	return pusher
}

// Close flushes any buffered messages, waits for the background response
// goroutine to finish, and stops the underlying producer. Safe to call
// multiple times; only the first call performs the shutdown. Always
// returns nil.
func (p *Pusher) Close() error {
	p.once.Do(func() {
		// Mark closed first so the executor's flush publishes the final
		// batch synchronously and closes responseChan (see NewPusher).
		p.closed.Set(true)
		if p.executor != nil {
			p.executor.Flush()
			p.executor.Wait()
		}
		// Give in-flight acknowledgements a moment to be drained.
		// NOTE(review): 100 microseconds looks suspiciously short — was
		// time.Millisecond intended? Confirm.
		time.Sleep(time.Microsecond * 100)
		// NOTE(review): if the executor had nothing pending, its callback
		// never runs and responseChan is never closed, so this Wait could
		// block indefinitely — verify against ChunkExecutor.Flush semantics.
		p.routineGroup.Wait()
		if p.produer != nil {
			p.produer.Stop()
		}
	})
	return nil
}

// Name reports the NSQ topic this pusher publishes to, identifying the
// pusher in queue-style interfaces.
func (p *Pusher) Name() string {
	return p.topic
}

// Push enqueues v for publication to the pusher's topic. When the chunk
// executor is available the message is batched; otherwise it is published
// asynchronously right away. The message body is staged in a pooled
// buffer, which is recycled once the publish is acknowledged (or, on a
// failed enqueue, immediately).
func (p *Pusher) Push(v string) error {
	buff := bytebuffer.Get()
	buff.WriteString(v)
	if p.executor != nil {
		if err := p.executor.Add(buff, len(v)); err != nil {
			// The executor rejected the task; no ack will arrive, so
			// recycle the buffer here instead of leaking it from the pool.
			bytebuffer.Put(buff)
			return err
		}
		return nil
	}
	// No executor: publish directly. The task slice is passed as a single
	// variadic argument so response() finds it at trans.Args[0].
	tasks := []interface{}{buff}
	return p.produer.PublishAsync(p.Name(), buff.Bytes(), p.responseChan, tasks)
}

// Ping checks connectivity by causing the underlying producer to connect
// to its nsqd if it is not connected already.
func (p *Pusher) Ping() error {
	return p.produer.Ping()
}
// response drains producer transaction acknowledgements and returns the
// pooled buffers attached to each transaction. It runs as a background
// goroutine and exits when responseChan is closed during shutdown.
//
// Bug fix: the original selected on the channel without checking for
// closure, so a receive after close yielded a nil transaction and
// panicked on trans.Args[0]; ranging over the channel exits cleanly.
func (p *Pusher) response() {
	for trans := range p.responseChan {
		tasks, ok := trans.Args[0].([]interface{})
		if !ok {
			// Unexpected payload shape; bail out entirely if shutting down.
			if p.closed.True() {
				return
			}
			continue
		}
		for _, task := range tasks {
			bytebuffer.Put(task.(*bytebuffer.ByteBuffer))
		}
		// NOTE(review): trans.Error is ignored — failed publishes are
		// silently dropped here; consider logging it.
	}
}

// WithChunkSize returns a PushOption that caps the number of buffered
// bytes before the executor triggers a flush.
func WithChunkSize(chunkSize int) PushOption {
	return func(o *chunkOptions) {
		o.chunkSize = chunkSize
	}
}

// WithFlushInterval returns a PushOption that sets the maximum time a
// buffered message may wait before being flushed.
func WithFlushInterval(interval time.Duration) PushOption {
	return func(o *chunkOptions) {
		o.flushInterval = interval
	}
}

// newOptions applies the given PushOptions and translates any non-zero
// settings into the equivalent ChunkExecutor options.
func newOptions(opts []PushOption) []executors.ChunkOption {
	var cfg chunkOptions
	for _, apply := range opts {
		apply(&cfg)
	}

	var chunkOpts []executors.ChunkOption
	if cfg.chunkSize > 0 {
		chunkOpts = append(chunkOpts, executors.WithChunkBytes(cfg.chunkSize))
	}
	if cfg.flushInterval > 0 {
		chunkOpts = append(chunkOpts, executors.WithFlushInterval(cfg.flushInterval))
	}
	return chunkOpts
}
